While our last example focused on using an MST to connect everything as cheaply as possible, the same structure can be used for the opposite goal: identifying distinct groups or clusters in a dataset.

  • This technique is powerful for tasks like image segmentation, where the goal is to group similar pixels together.
  • Each data point (or pixel) is a vertex $v$ in a graph $G$.
  • The weight $w(u, v)$ of an edge is a measure of dissimilarity between two points. For pixels, this could be the difference in color and brightness. A high weight means the points are very different.
  • We first compute the MST for the entire dataset. This connects all points with the minimum possible total dissimilarity.
  • The edges with the highest weights in the MST represent the weakest links between potential clusters. By removing these long edges, we can break the graph into separate, more tightly-knit components.
  • If we want to find $k$ clusters, we simply remove the $k-1$ most expensive edges from the MST.
Python: MST-based Clustering
1import numpy as np
2from scipy.spatial.distance import pdist, squareform
3from scipy.sparse.csgraph import minimum_spanning_tree
4
5# Define 2D data points with two clear clusters
6points = np.array([
7    [2, 3], [3, 4], [4, 3], [3, 2],  # Cluster 1
8    [9, 8], [10, 9], [11, 8], [10, 7] # Cluster 2
9])
10
11# 1. Create a complete graph where edge weights are Euclidean distances
12distance_matrix = squareform(pdist(points, 'euclidean'))
13
14# 2. Compute the Minimum Spanning Tree from the distance matrix
15mst_sparse = minimum_spanning_tree(distance_matrix)
16
17# 3. Find the longest edge in the MST to break for clustering
18rows, cols = mst_sparse.nonzero()
19weights = mst_sparse.data
20
21longest_edge_index = np.argmax(weights)
22
23u = rows[longest_edge_index]
24v = cols[longest_edge_index]
25max_weight = weights[longest_edge_index]
26
27print(f"To create 2 clusters, remove the 'weakest link' edge:")
28print(f"Edge ({u}, {v}) with weight {max_weight:.2f}")